#!/usr/bin/python

import sys
import subprocess

# C prototype of the fault-injection control function.  genfile() emits this
# text verbatim into the generated header, inside the
# APSW_FAULT_INJECT_INCLUDED guard.
proto = """
static long long
APSW_FaultInjectControl(const char *faultfunction, const char *filename, const char *funcname, int linenum, const char *args);

"""

# Template for the body of every generated wrapper macro, written as a GNU C
# statement expression.  "PySet_New" is only a placeholder: get_definition()
# replaces it with the name of the function being wrapped.  The expansion:
#
# * declares a result variable whose type comes from the real call (the
#   `0 ? call : 0` initializer never executes the call)
# * calls APSW_FaultInjectControl with the function name, __FILE__,
#   __func__, __LINE__ and the stringified arguments
# * if that returns 0x1FACADE, makes the real call and yields its result
# * if that returns 0x2FACADE, makes the real call, discards its result and
#   yields 18
# * otherwise yields whatever APSW_FaultInjectControl returned, cast to the
#   result type
call_pattern = """
({
    __auto_type _res_PySet_New = 0 ? PySet_New(__VA_ARGS__) : 0;

    _res_PySet_New = (typeof (_res_PySet_New))APSW_FaultInjectControl("PySet_New", __FILE__, __func__, __LINE__, #__VA_ARGS__);

    if ((typeof (_res_PySet_New))0x1FACADE == _res_PySet_New)
       _res_PySet_New = PySet_New(__VA_ARGS__);
    else if ((typeof(_res_PySet_New))0x2FACADE == _res_PySet_New)
    {
        PySet_New(__VA_ARGS__);
        _res_PySet_New = (typeof (_res_PySet_New))18;
    }
    _res_PySet_New;
})
"""

# The following APIs are deliberately not faulted:
#
# PyWeakref_GetRef - this is like a collection, so failing results in the
# underlying ref being orphaned and leaked
#
# PyList_SetSlice - we use this to remove a list item, which should always
# succeed


def get_definition(name, use_name):
    """Return the C macro body that wraps a call to *use_name*.

    The body is ``call_pattern`` with its ``PySet_New`` placeholder replaced
    by *use_name*.  When *name* differs from *use_name*, the string literal
    handed to ``APSW_FaultInjectControl`` is switched back to *name* so the
    pretty name is what gets reported.

    Every line except the last is space-padded to the width of the longest
    line and terminated with a backslash continuation, so the result can
    follow a ``#define`` line directly.
    """
    body = call_pattern.replace("PySet_New", use_name)
    if name != use_name:
        # only the quoted occurrence is the name given to the control function
        body = body.replace(f'"{use_name}"', f'"{name}"')

    *continued, final = body.strip().split("\n")
    width = max(len(line) for line in (*continued, final))
    return "".join(f"{line.ljust(width)} \\\n" for line in continued) + final


def genfile(symbols):
    """Return the text of the generated fault-injection header.

    *symbols* is an iterable of C function names; they are processed in
    sorted order.  The header is wrapped in ``#ifdef APSW_FAULT_INJECT`` and
    has two modes:

    * with ``APSW_FAULT_CLEAR`` defined, every symbol is ``#undef``-ed
    * otherwise every symbol is ``#define``-d as a variadic macro whose body
      comes from ``get_definition``

    Symbols listed in ``call_map`` get two definitions selected on
    ``PY_VERSION_HEX < 0x030d0000``: the first undefines the symbol and wraps
    the mapped name, the second wraps the symbol itself.
    """

    def define(sym, target):
        # one "#define sym(...)" line followed by the continued macro body
        return f"#define {sym}(...) \\\n{get_definition(sym, target)}"

    ordered = sorted(symbols)

    out = [f"""\
/*  DO NOT EDIT THIS FILE
    This file is generated by tools/genfaultinject.py
    Edit that not this */
#ifdef APSW_FAULT_INJECT

#ifndef APSW_FAULT_INJECT_INCLUDED
{ proto }
#define APSW_FAULT_INJECT_INCLUDED
#endif

#ifdef APSW_FAULT_CLEAR
"""]
    out.extend(f"#undef {sym}" for sym in ordered)
    out.append("\n#else\n")

    for sym in ordered:
        if sym not in call_map:
            out.append(define(sym, sym))
            continue
        out += [
            "#if PY_VERSION_HEX < 0x030d0000",
            f"#undef {sym}",
            define(sym, call_map[sym]),
            "#else",
            define(sym, sym),
            "#endif",
        ]

    # closes "#ifdef APSW_FAULT_CLEAR ... #else" and then "#ifdef APSW_FAULT_INJECT"
    out += ["#endif", "#endif"]
    return "\n".join(out)


# Names of the C functions that get a fault-injection wrapper macro, grouped
# by how they report failure.  Each value is a whitespace-separated list
# turned into a Python list with .split().  Header generation only uses the
# union of all the lists (built in the __main__ block); each individual list
# is checked for internal duplicates further down.
returns = {
    # return a pointer, NULL on failure
    "pointer": """
            convert_value_to_pyobject convert_column_to_pyobject  allocfunccbinfo
            apsw_strdup convertutf8string MakeExistingException get_window_function_context
            MakeTableChange

            PyBool_FromLong PyBytes_FromStringAndSize PyCode_NewEmpty PyDict_New
            PyErr_NewExceptionWithDoc PyFloat_AsDouble PyFloat_FromDouble PyFloat_FromString
            PyFrame_New PyFrozenSet_New PyIter_Next PyList_GetItem PyList_New
            PyLong_FromLong PyLong_FromLongLong PyLong_FromSize_t
            PyLong_FromSsize_t PyLong_FromUnsignedLongLong PyLong_FromVoidPtr
            PyMapping_GetItemString PyMem_Calloc PyMem_Malloc PyMem_Realloc
            PyMemoryView_FromMemory  PyModule_Create2
            PyNumber_Float PyNumber_Long PyObject_CallMethodNoArgs
            PyObject_CallObject PyObject_GetAttr PyObject_GetIter PyObject_Str
            PyObject_Vectorcall PyObject_VectorcallMethod PySequence_Fast
            PySequence_GetItem PySequence_GetSlice PySequence_List
            PySequence_SetItem PySequence_Tuple PySet_New PyStructSequence_New
            PyStructSequence_NewType PyTuple_New PyTuple_Pack
            PyType_FromModuleAndSpec PyType_GenericNew PyUnicode_AsUTF8
            PyUnicode_AsUTF8AndSize PyUnicode_AsUTF8String PyUnicode_DecodeUTF8
            PyUnicode_FromFormat PyUnicode_FromKindAndData PyUnicode_FromString
            PyUnicode_FromStringAndSize  PyUnicode_New PyWeakref_GetObject PyWeakref_NewRef Py_BuildValue
            Py_VaBuildValue _PyObject_New

            realloc

            Connection_fts5_api get_token_value fts5extensionapi_acquire

            sqlite3_malloc sqlite3_malloc64 sqlite3_mprintf
            sqlite3_realloc sqlite3_realloc64
            sqlite3_normalized_sql sqlite3_expanded_sql
            """.split(),
    # numeric return (SQLite and session extension API names)
    "sqlite": """
            sqlite3_aggregate_context sqlite3_autovacuum_pages
            sqlite3_backup_finish sqlite3_backup_init
            sqlite3_backup_step sqlite3_bind_blob sqlite3_bind_blob64
            sqlite3_bind_double sqlite3_bind_int sqlite3_bind_int64
            sqlite3_carray_bind sqlite3_carray_bind_apsw
            sqlite3_bind_null sqlite3_bind_pointer sqlite3_bind_text
            sqlite3_bind_text64 sqlite3_bind_value
            sqlite3_bind_zeroblob sqlite3_bind_zeroblob64
            sqlite3_blob_open sqlite3_blob_read sqlite3_blob_reopen
            sqlite3_blob_write sqlite3_busy_handler
            sqlite3_busy_timeout
            sqlite3_clear_bindings sqlite3_close sqlite3_close_v2
            sqlite3_collation_needed sqlite3_column_name
            sqlite3_complete sqlite3_config sqlite3_create_collation
            sqlite3_create_collation_v2 sqlite3_create_function
            sqlite3_create_function_v2 sqlite3_create_module
            sqlite3_create_module_v2 sqlite3_create_window_function
            sqlite3_db_cacheflush sqlite3_db_config sqlite3_db_status
            sqlite3_declare_vtab sqlite3_deserialize
            sqlite3_drop_modules sqlite3_enable_load_extension
            sqlite3_exec
            sqlite3_initialize
            sqlite3_load_extension
            sqlite3_open
            sqlite3_open_v2 sqlite3_overload_function
            sqlite3_prepare_v3
            sqlite3_result_zeroblob64
            sqlite3_set_authorizer sqlite3_shutdown sqlite3_status64
            sqlite3_table_column_metadata sqlite3_threadsafe
            sqlite3_trace_v2 sqlite3_vfs_register
            sqlite3_vfs_unregister sqlite3_vtab_config
            sqlite3_vtab_in_next sqlite3_vtab_rhs_value
            sqlite3_wal_autocheckpoint sqlite3_wal_checkpoint_v2

            sqlite3_preupdate_old sqlite3_preupdate_new

            sqlite3_prepare

            sqlite3changeset_apply_v2_strm sqlite3changeset_concat_strm
            sqlite3changeset_invert_strm sqlite3changeset_start_strm sqlite3changeset_start_v2_strm
            sqlite3session_changeset_strm sqlite3session_patchset_strm sqlite3changegroup_add_strm
            sqlite3changegroup_output_strm sqlite3rebaser_rebase_strm

            sqlite3changegroup_add sqlite3changegroup_add_change sqlite3changegroup_new sqlite3changegroup_output
            sqlite3changegroup_schema sqlite3changeset_apply_v2
            sqlite3changeset_concat sqlite3changeset_conflict sqlite3changeset_fk_conflicts
            sqlite3changeset_invert sqlite3changeset_new sqlite3changeset_next sqlite3changeset_old
            sqlite3changeset_op sqlite3changeset_pk sqlite3changeset_start sqlite3changeset_start_v2
            sqlite3changeset_upgrade sqlite3session_attach sqlite3session_changeset
            sqlite3session_config sqlite3session_create sqlite3session_diff
            sqlite3session_object_config sqlite3session_patchset
            sqlite3rebaser_configure sqlite3rebaser_create
            sqlite3rebaser_rebase

            """.split(),
    # py functions that return a number (-1) to indicate failure
    "number": """Py_EnterRecursiveCall
        PyType_Ready PyModule_AddObject PyModule_AddIntConstant PyModule_AddStringConstant
        PyLong_AsLong  PyLong_AsLongLong PyList_Append PyDict_SetItemString
        PyObject_SetAttr _PyBytes_Resize PyDict_SetItem
        PyObject_IsTrue PySequence_Size PySet_Add PyObject_IsTrueStrict
        PyStructSequence_InitType2 PyList_Size PyLong_AsInt
        PyList_SetItem PyList_Sort PySet_Contains PyObject_IsInstance
        PyMapping_Size PySet_Discard

        PyObject_GetBufferContiguous PyObject_GetBuffer PyObject_GetBufferContiguousBounded
        _PyTuple_Resize

        getfunctionargs

        jsonb_grow_buffer jsonb_add_tag jsonb_update_tag jsonb_append_data
        jsonb_add_tag_and_data jsonb_encode_internal jsonb_encode_object_key
        """.split(),
}

# Some calls like Py_BuildValue are #defined to _Py_BuildValue_SizeT, so
# map the pretty name to the name that is really called:
#
#   Py_BuildValue    -> _Py_BuildValue_SizeT
#   PyArg_ParseTuple -> _PyArg_ParseTuple_SizeT
#   Py_VaBuildValue  -> _Py_VaBuildValue_SizeT
call_map = {
    pretty: f"_{pretty}_SizeT"
    for pretty in ("Py_BuildValue", "PyArg_ParseTuple", "Py_VaBuildValue")
}

# Sanity check: a name must not be listed twice within the same category.
# The first repeated name found is reported and the script exits with
# status 1.
for category, names in returns.items():
    already = set()
    for candidate in names:
        if candidate in already:
            print(f"Duplicate item { candidate } in { category }")
            sys.exit(1)
        already.add(candidate)

# Python C API names that are deliberately NOT wrapped for fault injection.
# check_dll() uses this set to decide which undefined symbols are already
# accounted for, and asserts that no name is both here and in the faulted
# set, so the two must stay disjoint.

# these don't provide meaning for fault injection
no_error = set(
    """PyBuffer_Release PyDict_GetItem PyMem_Free PyDict_GetItemString PyErr_Clear
    PyErr_Display PyErr_Fetch PyErr_Format PyErr_NoMemory PyErr_NormalizeException
    PyErr_Occurred PyErr_Print PyErr_Restore PyErr_SetObject PyEval_RestoreThread
    PyEval_SaveThread PyGILState_Ensure PyGILState_Release PyOS_snprintf
    PyObject_CheckBuffer PyObject_ClearWeakRefs PyObject_GC_UnTrack PyObject_HasAttr
    PyThreadState_Get PyThread_get_thread_ident PyTraceBack_Here
    PyType_IsSubtype PyUnicode_CopyCharacters  _Py_Dealloc
    _Py_HashBytes _Py_NegativeRefcount _Py_RefTotal PyThreadState_GetFrame
""".split()
)

# these could error but are only used in a small number of places where
# errors are already dealt with
#
# PyObject_IsInstance is intentionally absent: it is faulted (listed under
# returns["number"]), and having it here as well trips the disjointness
# assertion in check_dll()
no_error.update(
    """PyArg_ParseTuple PyBytes_AsString PyErr_GivenExceptionMatches PyFrame_GetBack
    PyImport_ImportModule PyLong_AsLongAndOverflow PyLong_AsVoidPtr
    PySys_GetObject PyErr_ExceptionMatches
    PyErr_SetString PyStructSequence_SetItem PyObject_Print
    Py_GetRecursionLimit Py_LeaveRecursiveCall Py_SetRecursionLimit _PyErr_ChainExceptions

""".split()
)


def check_dll(fname, all):
    """Print the undefined ``Py`` symbols of *fname* that are not classified.

    Runs ``nm -u`` on *fname* and looks at every output line that starts
    with ``U`` (after stripping), contains ``Py`` and contains no ``@``.
    Symbol names that are values in ``call_map`` are translated back to
    their pretty name first.  A symbol counts as classified when it is in
    *all* (the faulted names) or in ``no_error``, ends with ``_Check``,
    ``_Type`` or ``Struct``, or starts with ``PyExc_``.

    The sorted list of remaining symbols and their count are printed.

    Raises ``subprocess.CalledProcessError`` if ``nm`` fails, and
    ``AssertionError`` if a symbol is in both *all* and ``no_error``.
    """
    nm = subprocess.run(["nm", "-u", fname], text=True, capture_output=True, check=True)
    ignorable_endings = ("_Check", "_Type", "Struct")
    unclassified = set()

    for entry in nm.stdout.split("\n"):
        relevant = entry.strip().startswith("U") and "@" not in entry and "Py" in entry
        if not relevant:
            continue

        # a relevant line is exactly "U <symbol>"
        _, sym = entry.split()

        if sym in all:
            assert sym not in no_error, f"{ sym } in all and no_error"

        if sym in call_map.values():
            # the binary references the mapped name; report the pretty one
            sym = next(pretty for pretty, actual in call_map.items() if actual == sym)

        if sym in all or sym in no_error:
            continue
        if sym.endswith(ignorable_endings) or sym.startswith("PyExc_"):
            continue

        unclassified.add(sym)

    print(sorted(unclassified))
    print(len(unclassified), "items")


if __name__ == "__main__":
    # With an argument ending in ".h" the fault-injection header is generated
    # and written to that path.  Any other argument is treated as a binary
    # whose undefined symbols are checked against the known lists.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} output.h | binary-to-check")

    target = sys.argv[1]

    # every name that gets a wrapper, regardless of its return category
    all_symbols = set()
    for names in returns.values():
        all_symbols.update(names)

    if target.endswith(".h"):
        # generate first so a failure does not truncate an existing header
        content = genfile(all_symbols)
        # the context manager guarantees the output is flushed and closed
        with open(target, "wt") as header:
            header.write(content)
    else:
        check_dll(target, all_symbols)
